# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hysop.backend.device.opencl.opencl_types import OpenClTypeGen
from hysop.tools.htypes import check_instance
import sympy as sm
from packaging import version
if version.parse(sm.__version__) > version.parse("1.7"):
from sympy.printing.c import C99CodePrinter
else:
from sympy.printing.ccode import C99CodePrinter
# /!\ TODO complete known_functions list with OpenCL builtins
# - keys are sympy function names (beware of capital letters)
# - values are either strings or lists of tuples (predicate(inputs), string)
# that correspond to OpenCL function builtins.
# Here are some attributes that can be checked in predicates:
# is_zero
# is_finite is_integer
# is_negative is_positive
# is_rational is_real
# Mapping from sympy function names to OpenCL builtin names.
# A value is either a plain string (unconditional mapping) or a list of
# (predicate(*args), name) tuples.  Every entry of such a list is a genuine
# 2-tuple: the last one carries an always-true predicate so that it acts as
# the fallback when the integer-specific predicate does not hold.
# NOTE(review): sympy names its classes "Min"/"Max"; confirm that the
# lowercase "min"/"max" keys below are the ones actually looked up.
known_functions = {
    "Abs": [
        (lambda x: x.is_integer, "abs"),
        (lambda x: True, "fabs"),
    ],
    "min": [
        (lambda x, y: x.is_integer and y.is_integer, "min"),
        (lambda x, y: True, "fmin"),
    ],
    "max": [
        (lambda x, y: x.is_integer and y.is_integer, "max"),
        (lambda x, y: True, "fmax"),
    ],
    "sqrt": "sqrt",
    "gamma": "tgamma",
    "sin": "sin",
    "cos": "cos",
    "tan": "tan",
    "asin": "asin",
    "acos": "acos",
    "atan": "atan",
    "atan2": "atan2",
    "sinh": "sinh",
    "cosh": "cosh",
    "tanh": "tanh",
    "asinh": "asinh",
    "acosh": "acosh",
    "atanh": "atanh",
    "exp": "exp",
    "log": "log",
    "erf": "erf",
    "floor": "floor",
    "ceiling": "ceil",
}
# OpenCl 2.2 reserved keywords (see opencl documentation)
# Identifiers that must not be used as variable names in generated OpenCL
# code.  Each keyword is its own list element: "throw" and "override" are
# separated by a comma so that adjacent string literals are not silently
# concatenated into a single bogus entry.
reserved_words = [
    # C++14 keywords
    "alignas", "continue", "friend", "register", "true",
    "alignof", "decltype", "goto", "reinterpret_cast", "try",
    "asm", "default", "if", "return", "typedef",
    "auto", "delete", "inline", "short", "typeid",
    "bool", "do", "int", "signed", "typename",
    "break", "double", "long", "sizeof", "union",
    "case", "dynamic_cast", "mutable", "static", "unsigned",
    "catch", "else", "namespace", "static_assert", "using",
    "char", "enum", "new", "static_cast", "virtual",
    "char16_t", "explicit", "noexcept", "struct", "void",
    "char32_t", "export", "nullptr", "switch", "volatile",
    "class", "extern", "operator", "template", "wchar_t",
    "const", "false", "private", "this", "while",
    "constexpr", "float", "protected", "thread_local",
    "const_cast", "for", "public", "throw",
    "override",
    "final",
    # OpenCl data types
    "uchar", "ushort", "uint", "ulong", "half",
    "bool2", "char2", "uchar2", "short2", "ushort2", "int2",
    "uint2", "long2", "ulong2", "half2", "float2", "double2",
    "bool3", "char3", "uchar3", "short3", "ushort3", "int3",
    "uint3", "long3", "ulong3", "half3", "float3", "double3",
    "bool4", "char4", "uchar4", "short4", "ushort4", "int4",
    "uint4", "long4", "ulong4", "half4", "float4", "double4",
    "bool8", "char8", "uchar8", "short8", "ushort8", "int8",
    "uint8", "long8", "ulong8", "half8", "float8", "double8",
    "bool16", "char16", "uchar16", "short16", "ushort16", "int16",
    "uint16", "long16", "ulong16", "half16", "float16", "double16",
    # function qualifiers
    "kernel", "__kernel",
    # access qualifiers
    "read_only", "write_only", "read_write",
    "__read_only", "__write_only", "__read_write",
]
[docs]
class OpenClPrinter(C99CodePrinter):
    """
    A printer to convert sympy expressions to strings of opencl code.

    Numeric literals (rationals, floats) and any expression without a
    dedicated ``_print_<Class>`` method are delegated to ``typegen.dump``.
    Symbols found in ``symbol2vars`` are printed through their mapped value.
    """

    printmethod = "_clcode"
    language = "OpenCL"

    _default_settings = {
        "order": None,
        "full_prec": "auto",
        "precision": None,
        "user_functions": {},
        "human": True,
        "contract": True,
        "dereference": set(),
        "error_on_reserved": True,
        "reserved_word_suffix": None,
    }

    def __init__(self, typegen, symbol2vars=None, **settings):
        """
        Parameters
        ----------
        typegen: OpenClTypeGen
            Object whose ``dump`` method renders numeric values.
        symbol2vars: dict, optional
            Mapping from sympy symbols to the objects that should be printed
            in their place. May be None, in which case no symbol is replaced.
        settings:
            Extra keyword arguments forwarded as printer settings.
        """
        check_instance(typegen, OpenClTypeGen)
        check_instance(symbol2vars, dict, keys=sm.Symbol, allow_none=True)
        super().__init__(settings=settings)

        # Override the parent tables with the OpenCL specific ones.
        self.known_functions = dict(known_functions)
        self.reserved_words = set(reserved_words)

        self.typegen = typegen
        self.symbol2vars = symbol2vars

    def dump_symbol(self, expr):
        """Print a symbol, substituting its entry in symbol2vars if any."""
        symbol2vars = self.symbol2vars
        # symbol2vars is optional (None by default): guard the membership
        # test so that printing a symbol without a mapping does not raise.
        if symbol2vars is not None and expr in symbol2vars:
            return self._print(symbol2vars[expr])
        return super()._print_Symbol(expr)

    def dump_rational(self, expr):
        """Print a rational value through the type generator."""
        return self.typegen.dump(expr)

    def dump_float(self, expr):
        """Print a floating point value through the type generator."""
        return self.typegen.dump(expr)

    def _print_Symbol(self, expr):
        return self.dump_symbol(expr)

    def _print_Rational(self, expr):
        return self.dump_rational(expr)

    def _print_PythonRational(self, expr):
        return self.dump_rational(expr)

    def _print_Fraction(self, expr):
        return self.dump_rational(expr)

    def _print_mpq(self, expr):
        return self.dump_rational(expr)

    def _print_Float(self, expr):
        return self.dump_float(expr)

    # last resort printer (if _print_CLASS is not found)
    def emptyPrinter(self, expr):
        """Fallback printer: delegate to the type generator."""
        return self.typegen.dump(expr)
[docs]
def dump_clcode(expr, typegen, **kargs):
    """
    Build an OpenClPrinter from typegen and the extra keyword arguments,
    and return the OpenCL code string it produces for expr.
    """
    return OpenClPrinter(typegen=typegen, **kargs).doprint(expr)
[docs]
def print_clcode(expr, typegen, **kargs):
    """
    Write the OpenCL code string generated by dump_clcode for expr
    to standard output.
    """
    code = dump_clcode(expr, typegen=typegen, **kargs)
    print(code)